functionality to use pretrained pie model #168
Merged
+26
−1
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR adds functionality to use a pretrained PyTorch-IE model that allows updating weights of base transformer model. Additionally, it provides an optional parameter,
pretrained_pie_model_prefix_mapping
, to filter and specify which model layers should be updated. Furthermore, this adds the optional parameterpretrained_pie_model_prefix_mapping
that allows to restrict the loading to a subset of parameters via their prefixes.For instance, to load the weights from a PIE token classification model for further training on conll2003:
python src/train.py experiment=conll2003 pretrained_pie_model_path=path/to/pretrained/pie/model "+pretrained_pie_model_prefix_mapping={model.model:model.model}"
Note that we use
model.model:model.model
as prefix mapping which will result in only loading the weights from the base transformer model, but not the classification head. This is useful to fine-tune on data with a different set of labels.In addition, this also adds imports for all models and task modules from the pie_modules package to the train script because these are required when loading a such a model vie
pytorch_ie.AutoModel
.